#!/usr/bin/python3
# -*- coding:utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

import os
import sys
import numpy as np

LOSS = 1e-3 # 容忍偏差，一般fp16要求绝对误差和相对误差均不超过千分之一
MINIMUM = 10e-10


def verify_result(real_result, golden):
    cdtype = os.getenv('COMPUTE_TYPE')
    if cdtype == 'float16':
        dtype = np.float16
    elif cdtype == 'float32':
        dtype = np.float32
    else:
        dtype = np.int32
    real_result = np.fromfile(real_result, dtype=dtype).astype(np.int32) # 从bin文件读取实际运算结果
    golden = np.fromfile(golden, dtype=dtype).astype(np.int32) # 从bin文件读取预期运算结果
    if len(real_result) != len(golden):
        print(len(real_result), len(golden))
        print("[ERROR] result error")
        return False
    for i, _ in enumerate(real_result):
        if real_result[i] != golden[i]:
            print("[ERROR] result error %d except %d but %d", i, golden[i], real_result[i])
            return False
    print("test pass")
    return True

if __name__ == '__main__':
    verify_result(sys.argv[1], sys.argv[2])
